有码有颜!你要的生成模型VQ-VAE来了!
设为星标,干货直达!
上一篇文章生成模型之PixelCNN介绍了基于自回归的生成模型,这篇文章将介绍DeepMind(和PixelCNN同一作)于2017年提出的一种基于离散隐变量(Discrete Latent variables)的生成模型:VQ-VAE。VQ-VAE相比VAE有两个重要的区别:首先VQ-VAE采用离散隐变量,而不是像VAE那样采用连续的隐变量;然后VQ-VAE需要单独训练一个基于自回归的模型如PixelCNN来学习先验(prior),而不是像VAE那样采用一个固定的先验(标准正态分布)。此外,VQ-VAE还是一个强大的无监督表征学习模型,它学习的离散编码具有很强的表征能力,最近比较火的文本转图像模型DALL-E也是基于VQ-VAE的,而且最近的一些基于masked image modeling的无监督学习方法如BEiT也用VQ-VAE得到的离散编码作为训练目标。这篇文章将讲解VQ-VAE的原理以及具体的代码实现。
VQ-VAE
一个VAE模型包括三个部分:后验分布(Posterior),先验分布(Prior),以及似然(Likelihood)。其中后验分布用encoder网络来学习,似然用decoder网络来学习,而先验分布采用参数固定的标准正态分布。在VAE学习过程中,后验分布往往假定是一个对角方差的多元正态分布,而隐变量是一个连续的随机变量。
VQ-VAE与VAE的最主要的区别是VQ-VAE采用离散隐变量,如下图所示,对于encoder的输出通过向量量化(vector quantisation,VQ)的方法来离散化。至于为啥采用离散编码,论文的slides给出了三个主要原因:
许多重要的事物都是离散的,如语言; 更容易对先验建模,VQ-VAE采用PixelCNN来学习先验,离散编码只需要简单地采用softmax多分类; 连续的表征往往被encoder/decoder内在地离散化;
实际上就是计算和每个embedding向量的欧式距离,然后选择距离最近的embedding向量作为量化值。这里得到的量化值将作为decoder的输入,所以VQ-VAE的整个过程变成:
原来的VAE的后验分布是多元高斯分布,但对于VQ-VAE,经过VQ之后,后验分布可以看成是一个多类分布(categorical distribution),而且其概率分布为one-hot类型:
此时,后验分布其实变成了一个确定的分布,因为确定的过程没有任何随机因素。如果我们定义先验分布为一个均匀的多类分布的话(每个类别的概率均为),此时就可以直接简单计算出后验分布和先验分布的KL散度:
此时KL散度就是一个常量,那么VQ-VAE的训练损失就剩下了一项重建误差。实际上VQ-VAE的训练过程就没有用到先验分布,所以后面我们需要单独训练一个先验模型来生成数据,这是VQ-VAE和VAE的第二个区别。VQ-VAE分成两个阶段来得到生成模型,可以避免VAE训练过程中容易出现的“posterior collapse”。
VQ-VAE还存在一个问题,那就是由于argmin操作不可导,所以重建误差的梯度就无法传导到encoder。论文采用straight-through estimator来解决这个问题,所谓straight-through estimator其实就是一种用来估计一些不可导函数梯度的方法,如下图所示,threshold function是不可导的,此时我们在计算梯度时,直接忽略它而采用上游得到的梯度,这个行为就认为threshold function是一个identity function一样。
这里的sg指的是stop gradient操作,这意味着这个L2损失只会更新embedding空间,而不会传导到encoder。这里,我们也可以采用另外一种方式:**指数移动平均(exponential moving averages,EMA)**来更新embedding空间。假定为一系列和embedding向量对应的encoder的输出,此时L2损失为:
此时embedding向量的最优值有解析解,即对所有的元素求平均值:然而,训练过程中无法直接这样更新,因为训练是基于mini-batch的,并不是训练数据的全部。类比BatchNorm,我们可以采用EMA来更新embedding:
这里共需要维护两套EMA参数,一是每个embedding向量的对应的元素数量,二是的求和值。每次forward时,我们根据当前mini-batch得到,然后执行EMA,而用除以即可得到当前的embedding向量。采用EMA这种更新方式往往比直接采用L2损失收敛速度更快,论文采用的decay值为0.99。
除此之外,论文还额外增加一个训练loss:commitment loss,这个主要是约束encoder的输出和embedding空间保持一致,以避免encoder的输出变动较大(从一个embedding向量转向另外一个)。commitment loss也比较简单,直接计算encoder的输出和对应的量化得到的embedding向量的L2误差:
注意这里的sg是作用在embedding向量上,这意味着这个约束只会影响encoder。
综上,VQ-VAE共包含三个部分的训练loss:reconstruction loss,VQ loss,commitment loss。
其中reconstruction loss作用在encoder和decoder上,VQ loss用来更新embedding空间(也可用EMA方式),而commitment loss用来约束encoder,这里的为权重系数,论文默认设置为0.25。
另外,在实际实验中,一张图像会采用个离散隐变量,这个和encoder得到的特征图大小有关。对于ImageNet数据集,采用32x32大小的中间特征图,所以;对于CIFAR10数据集,采用8x8大小的中间特征图,所以。对于一张图像,VQ loss和commitment loss取个离散编码的loss平均值。从自动编码器的角度来看,VQ-VAE实现了对图像的压缩,即将一张图像压缩成个离散编码。这里要说明的一点是,VQ-VAE适用于多种模态的数据,除了图像之外,还可以用于语音和视频的生成,这里只讨论图像。
训练好VQ-VAE后,还需要训练一个先验模型来完成数据生成,对于图像来说,可以采用PixelCNN模型,这里我们不再是学习生成原始的pixels,而是学习生成离散编码。首先,我们需要用已经训练好的VQ-VAE模型对训练图像推理,得到每张图像对应的离散编码;然后用一个PixelCNN来对离散编码进行建模,最后的预测层采用基于softmax的多分类,类别数为embedding空间的大小。那么,生成图像的过程就比较简单了,首先用训练好的PixelCNN模型来采样一个离散编码样本,然后送入VQ-VAE的encoder中,得到生成的图像。整个过程如下图所示:
VQ-VAE-2
VQ-VAE-2是DeepMind团队于2019年提出的VQ-VAE的升级版,相比VQ-VAE,VQ-VAE-2采用多尺度的层级结构,如下图所示,这里采用了两个尺度的特征来进行量化。采用多尺度的好处是可以用将图像的局部特征和全局特征来分别建模,比如这里的Bottom Level的特征用于提取局部信息,而Top Level的特征用于提出全局信息。而且采用层级结构将可以用来生成尺寸较大的图像。
VQ-VAE的代码实现
这里参考Keras vq_vae blog和官方代码来用PyTorch实现VQ-VAE,首先我们以MNIST数据集来实现VQ-VAE的标准版本(非EMA)。首先要实现的是VQ-VAE最核心的部分:向量量化VQ,这里我们也将训练loss的实现放在了类的forward中,不过区分train和eval模式:如果是train模式,除了返回量化后的特征外,还返回VQ loss+commitment loss;而对于eval模式只返回量化后的特征。
class VectorQuantizer(nn.Module):
"""
VQ-VAE layer: Input any tensor to be quantized.
Args:
embedding_dim (int): the dimensionality of the tensors in the
quantized space. Inputs to the modules must be in this format as well.
num_embeddings (int): the number of vectors in the quantized space.
commitment_cost (float): scalar which controls the weighting of the loss terms (see
equation 4 in the paper - this variable is Beta).
"""
def __init__(self, embedding_dim, num_embeddings, commitment_cost):
super().__init__()
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.commitment_cost = commitment_cost
# initialize embeddings
self.embeddings = nn.Embedding(self.num_embeddings, self.embedding_dim)
def forward(self, x):
# [B, C, H, W] -> [B, H, W, C]
x = x.permute(0, 2, 3, 1).contiguous()
# [B, H, W, C] -> [BHW, C]
flat_x = x.reshape(-1, self.embedding_dim)
encoding_indices = self.get_code_indices(flat_x)
quantized = self.quantize(encoding_indices)
quantized = quantized.view_as(x) # [B, H, W, C]
if not self.training:
quantized = quantized.permute(0, 3, 1, 2).contiguous()
return quantized
# embedding loss: move the embeddings towards the encoder's output
q_latent_loss = F.mse_loss(quantized, x.detach())
# commitment loss
e_latent_loss = F.mse_loss(x, quantized.detach())
loss = q_latent_loss + self.commitment_cost * e_latent_loss
# Straight Through Estimator
quantized = x + (quantized - x).detach()
quantized = quantized.permute(0, 3, 1, 2).contiguous()
return quantized, loss
def get_code_indices(self, flat_x):
# compute L2 distance
distances = (
torch.sum(flat_x ** 2, dim=1, keepdim=True) +
torch.sum(self.embeddings.weight ** 2, dim=1) -
2. * torch.matmul(flat_x, self.embeddings.weight.t())
) # [N, M]
encoding_indices = torch.argmin(distances, dim=1) # [N,]
return encoding_indices
def quantize(self, encoding_indices):
"""Returns embedding tensor for a batch of indices."""
return self.embeddings(encoding_indices)
对于encoder和decoder,我们采用对称的结构,其中decoder采用stride=2的反卷积来进行上采用:
class Encoder(nn.Module):
"""Encoder of VQ-VAE"""
def __init__(self, in_dim=3, latent_dim=16):
super().__init__()
self.in_dim = in_dim
self.latent_dim = latent_dim
self.convs = nn.Sequential(
nn.Conv2d(in_dim, 32, 3, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, 3, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, latent_dim, 1),
)
def forward(self, x):
return self.convs(x)
class Decoder(nn.Module):
"""Decoder of VQ-VAE"""
def __init__(self, out_dim=1, latent_dim=16):
super().__init__()
self.out_dim = out_dim
self.latent_dim = latent_dim
self.convs = nn.Sequential(
nn.ConvTranspose2d(latent_dim, 64, 3, stride=2, padding=1, output_padding=1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(32, out_dim, 3, padding=1),
)
def forward(self, x):
return self.convs(x)
有了VQ,encoder和decoder,就可以定义VQ-VAE模型了,这里也同样区分train和eval模式。此外,这里采用L2的重建误差,而且我们用训练数据的标准差来归一化这个误差。
class VQVAE(nn.Module):
"""VQ-VAE"""
def __init__(self, in_dim, embedding_dim, num_embeddings, data_variance,
commitment_cost=0.25):
super().__init__()
self.in_dim = in_dim
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.data_variance = data_variance
self.encoder = Encoder(in_dim, embedding_dim)
self.vq_layer = VectorQuantizer(embedding_dim, num_embeddings, commitment_cost)
self.decoder = Decoder(in_dim, embedding_dim)
def forward(self, x):
z = self.encoder(x)
if not self.training:
e = self.vq_layer(z)
x_recon = self.decoder(e)
return e, x_recon
e, e_q_loss = self.vq_layer(z)
x_recon = self.decoder(e)
recon_loss = F.mse_loss(x_recon, x) / self.data_variance
return e_q_loss + recon_loss
下图为训练的VQ-VAE的模型在测试集上的重建效果,其中上面一行为原图,而下面一行对重建的图,可以看到VQ-VAE基本能完美重建MNIST数据集。
# get encode_indices of training images
train_indices = []
for images, labels in train_loader:
images = images - 0.5 # normalize to [-0.5, 0.5]
images = images.cuda()
with torch.inference_mode():
z = model.encoder(images) # [B, C, H, W]
b, c, h, w = z.size()
# [B, C, H, W] -> [B, H, W, C]
z = z.permute(0, 2, 3, 1).contiguous()
# [B, H, W, C] -> [BHW, C]
flat_z = z.reshape(-1, c)
encoding_indices = model.vq_layer.get_code_indices(flat_z) # [BHW,]
encoding_indices = encoding_indices.reshape(b, h, w)
train_indices.append(encoding_indices.cpu())
这里我们采用GatedPixelCNN模型来学习先验,训练好先验模型后,可以先用先验模型随机采样得到离散编码,然后送入VQ-VAE的decoder得到生成的图像,下图为一些生成样例:
class ExponentialMovingAverage(nn.Module):
"""Maintains an exponential moving average for a value.
This module keeps track of a hidden exponential moving average that is
initialized as a vector of zeros which is then normalized to give the average.
This gives us a moving average which isn't biased towards either zero or the
initial value. Reference (https://arxiv.org/pdf/1412.6980.pdf)
Initially:
hidden_0 = 0
Then iteratively:
hidden_i = hidden_{i-1} - (hidden_{i-1} - value) * (1 - decay)
average_i = hidden_i / (1 - decay^i)
"""
def __init__(self, init_value, decay):
super().__init__()
self.decay = decay
self.counter = 0
self.register_buffer("hidden", torch.zeros_like(init_value))
def forward(self, value):
self.counter += 1
self.hidden.sub_((self.hidden - value) * (1 - self.decay))
average = self.hidden / (1 - self.decay ** self.counter)
return average
然后来实现基于EMA的VQ,这里维护了两个EMA参数,分别是每个embedding向量对应的encoder输出集合的特征数量以及特征之和,然后我们去掉原来的VQ loss而采用EMA来更新embeddings,注意这个过程要忽略梯度:
class VectorQuantizerEMA(nn.Module):
"""
VQ-VAE layer: Input any tensor to be quantized. Use EMA to update embeddings.
Args:
embedding_dim (int): the dimensionality of the tensors in the
quantized space. Inputs to the modules must be in this format as well.
num_embeddings (int): the number of vectors in the quantized space.
commitment_cost (float): scalar which controls the weighting of the loss terms (see
equation 4 in the paper - this variable is Beta).
decay (float): decay for the moving averages.
epsilon (float): small float constant to avoid numerical instability.
"""
def __init__(self, embedding_dim, num_embeddings, commitment_cost, decay,
epsilon=1e-5):
super().__init__()
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.commitment_cost = commitment_cost
self.epsilon = epsilon
# initialize embeddings as buffers
embeddings = torch.empty(self.num_embeddings, self.embedding_dim)
nn.init.xavier_uniform_(embeddings)
self.register_buffer("embeddings", embeddings)
self.ema_dw = ExponentialMovingAverage(self.embeddings, decay)
# also maintain ema_cluster_size, which record the size of each embedding
self.ema_cluster_size = ExponentialMovingAverage(torch.zeros((self.num_embeddings,)), decay)
def forward(self, x):
# [B, C, H, W] -> [B, H, W, C]
x = x.permute(0, 2, 3, 1).contiguous()
# [B, H, W, C] -> [BHW, C]
flat_x = x.reshape(-1, self.embedding_dim)
encoding_indices = self.get_code_indices(flat_x)
quantized = self.quantize(encoding_indices)
quantized = quantized.view_as(x) # [B, H, W, C]
if not self.training:
quantized = quantized.permute(0, 3, 1, 2).contiguous()
return quantized
# update embeddings with EMA
with torch.no_grad():
encodings = F.one_hot(encoding_indices, self.num_embeddings).float()
updated_ema_cluster_size = self.ema_cluster_size(torch.sum(encodings, dim=0))
n = torch.sum(updated_ema_cluster_size)
updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon) /
(n + self.num_embeddings * self.epsilon) * n)
dw = torch.matmul(encodings.t(), flat_x) # sum encoding vectors of each cluster
updated_ema_dw = self.ema_dw(dw)
normalised_updated_ema_w = (
updated_ema_dw / updated_ema_cluster_size.reshape(-1, 1))
self.embeddings.data = normalised_updated_ema_w
# commitment loss
e_latent_loss = F.mse_loss(x, quantized.detach())
loss = self.commitment_cost * e_latent_loss
# Straight Through Estimator
quantized = x + (quantized - x).detach()
quantized = quantized.permute(0, 3, 1, 2).contiguous()
return quantized, loss
def get_code_indices(self, flat_x):
# compute L2 distance
distances = (
torch.sum(flat_x ** 2, dim=1, keepdim=True) +
torch.sum(self.embeddings ** 2, dim=1) -
2. * torch.matmul(flat_x, self.embeddings.t())
) # [N, M]
encoding_indices = torch.argmin(distances, dim=1) # [N,]
return encoding_indices
def quantize(self, encoding_indices):
"""Returns embedding tensor for a batch of indices."""
return F.embedding(encoding_indices, self.embeddings)
这里用EMA版本的VQ-VAE在CIFAR10数据集训练之后在测试集上的重建效果:
小结
本文简单地介绍了VQ-VAE的原理以及具体的代码实现,相比VAE,VQ-VAE采用离散编码,这也使得VQ-VAE需要两个阶段来得到生成模型。最近OpenAI提出的文本转图像的生成模型DALL-E更让我们体会到了VQ-VAE的强大之处。
参考
Neural Discrete Representation Learning Generating Diverse High-Fidelity Images with VQ-VAE-2 https://github.com/deepmind/sonnet/blob/v2/examples/vqvae_example.ipynb https://keras.io/examples/generative/vq_vae/ https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py https://avdnoord.github.io/homepage/slides/SANE2017.pdf
推荐阅读
辅助模块加速收敛,精度大幅提升!移动端实时的NanoDet-Plus来了!
机器学习算法工程师
一个用心的公众号